import torch
import torch.nn as nn
import torch.optim as optim

#定义模型
model = nn.Linear(10, 1)

#定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr = 0.01)

#训练数据

#保存模型
MODEL_PATH = "./model02.pth"
#获取网络模型当前参数
state = model.state_dict()
torch.save(model.state_dict(), MODEL_PATH)

#加载模型
model.load_state_dict(torch.load(MODEL_PATH))

#TODO 如何使用模型进行预测